{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 融合 Conv2dReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.relax.backend.cuda.cublas import partition_for_cublas\n",
    "from tvm.relax.backend.cuda.cutlass import partition_for_cutlass\n",
    "from tvm.relax.dpl.pattern import (\n",
    "    is_op,\n",
    "    is_tuple_get_item,\n",
    "    make_fused_bias_activation_pattern,\n",
    "    wildcard,\n",
    ")\n",
    "from tvm.relax.transform import PatternCheckContext\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class Conv2dReLU:\n",
    "    @R.function\n",
    "    def main(\n",
    "        data: R.Tensor((1, 64, 56, 56), \"float32\"),\n",
    "        weight1: R.Tensor((64, 64, 3, 3), \"float32\"),\n",
    "    ):\n",
    "        with R.dataflow():\n",
    "            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))\n",
    "            R.output(conv1)\n",
    "\n",
    "        return conv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, weight1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv1\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "Conv2dReLU.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv2d_relu_pat = make_fused_bias_activation_pattern(\"relax.nn.conv2d\", activation=\"relax.nn.relu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 绑定参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight_np = np.random.randn(64, 64, 3, 3).astype(\"float32\")\n",
    "mod = tvm.transform.Sequential(\n",
    "    [\n",
    "        relax.transform.BindParams(\"main\", {\"weight1\": weight_np}),\n",
    "        relax.transform.FuseOpsByPattern(\n",
    "            [(\"dnnl.conv2d_relu\", conv2d_relu_pat)], bind_constants=True\n",
    "        ),\n",
    "    ]\n",
    ")(Conv2dReLU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function(private<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>, <span style=\"color: #BA2121\">&quot;Primitive&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu(data)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert \"fused_relax_nn_conv2d_relax_nn_relu\" in [var.name_hint for var in mod.functions.keys()]\n",
    "\n",
    "for gvar, f in mod.functions.items():\n",
    "    if gvar.name_hint == \"fused_relax_nn_conv2d_relax_nn_relu\":\n",
    "        conv2d = f.body.blocks[0].bindings[0].value\n",
    "        assert isinstance(conv2d.args[1], relax.Constant)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `FuseOpsByPattern`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "patterns = [(\"dnnl.conv2d_relu\", conv2d_relu_pat)]\n",
    "bind_constants = True\n",
    "annotate_codegen = True\n",
    "partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu_dnnl</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(data_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data_1, weight1_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(data, weight1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "partitioned.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "patterns = []\n",
    "bind_constants = True\n",
    "annotate_codegen = True\n",
    "partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, weight1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv1\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "partitioned.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, weight1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv1\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "Conv2dReLU.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 融合 `Conv2dReLUx2`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class Conv2dReLUx2:\n",
    "    @R.function\n",
    "    def main(\n",
    "        data: R.Tensor((1, 64, 56, 56), \"float32\"),\n",
    "        weight1: R.Tensor((64, 64, 3, 3), \"float32\"),\n",
    "        weight2: R.Tensor((64, 64, 3, 3), \"float32\"),\n",
    "    ):\n",
    "        with R.dataflow():\n",
    "            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))\n",
    "            conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))\n",
    "            R.output(conv2)\n",
    "\n",
    "        return conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, weight1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(conv1, weight2, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv2)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv2\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "Conv2dReLUx2.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "patterns = [(\"dnnl.conv2d_relu\", conv2d_relu_pat)]\n",
    "bind_constants = True\n",
    "annotate_codegen = True\n",
    "partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu1_dnnl</span>(conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(conv1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(conv1_1, weight2_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv2)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(conv1, weight2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu_dnnl</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(data_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data_1, weight1_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(data, weight1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu1_dnnl(lv, weight2)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "partitioned.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模板匹配依赖顺序："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv2d_pat = make_fused_bias_activation_pattern(\"relax.nn.conv2d\", activation=None)\n",
    "patterns =  [(\"dnnl.conv2d\", conv2d_pat), (\"dnnl.conv2d_relu\", conv2d_relu_pat)]\n",
    "bind_constants = True\n",
    "annotate_codegen = True\n",
    "partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 变换器会优先尝试匹配第一个模式 `conv2d_pat`\n",
    "- 即使存在更复杂的 `conv2d_relu` 模式，由于顺序优先，会先应用简单的 `conv2d` 融合\n",
    "- 最终结果中只有 `conv2d` 算子被融合，ReLU 算子保持独立"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d1_dnnl</span>(conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(conv1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(conv1_1, weight2_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(conv1, weight2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_dnnl</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(data_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data_1, weight1_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(data, weight1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_dnnl(data, weight1)\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d1_dnnl(conv1, weight2)\n",
       "            conv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv1)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv2\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "partitioned.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 融合 `Conv2dConv2dReLU`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class Conv2dConv2dReLU:\n",
    "    @R.function\n",
    "    def main(\n",
    "        data: R.Tensor((1, 64, 56, 56), \"float32\"),\n",
    "        weight1: R.Tensor((64, 64, 3, 3), \"float32\"),\n",
    "        weight2: R.Tensor((64, 64, 3, 3), \"float32\"),\n",
    "    ):\n",
    "        with R.dataflow():\n",
    "            conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))\n",
    "            conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))\n",
    "            R.output(conv2d)\n",
    "\n",
    "        return conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data, weight1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(conv1, weight2, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            conv2d: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv1)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(conv2d)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> conv2d\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "Conv2dConv2dReLU.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv2d_pat = make_fused_bias_activation_pattern(\"relax.nn.conv2d\", activation=None)\n",
    "patterns =  [(\"dnnl.conv2d_relu\", conv2d_relu_pat), (\"dnnl.conv2d\", conv2d_pat)]\n",
    "bind_constants = True\n",
    "annotate_codegen = True\n",
    "partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu1_dnnl</span>(conv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(conv1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(conv1_1, weight2_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv2)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(conv1, weight2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_nn_conv2d_relax_nn_relu_dnnl</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(data_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;dnnl.conv2d_relu&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(data_1, weight1_1, strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(data, weight1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_nn_conv2d_relax_nn_relu1_dnnl(lv, weight2)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "partitioned.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
